/*
* Copyright (C) 2025 ByteDance and/or its affiliates
*
* This program is free software: you can redistribute it and/or modify
* it under the terms of the GNU General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* This program is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
* GNU General Public License for more details.
*
* You should have received a copy of the GNU General Public License
* along with this program.  If not, see <https://www.gnu.org/licenses/>.
*/

#include "bytedock/core/data.h"
#include "bytedock/ext/logging.h"

#include "test_lib.h"

namespace bytedock {

/**
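 * The ligand coordinates and geometry used by ConformerTest were generated by
 * the script below. The embedded JSON is a trimmed version of its output: it
 * keeps `xyz` and `geometry`, leaves every `ffdata` array empty, and omits
 * `mapped_smiles` and `molecular_weight`.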
 * ```python
 * from bytevs.op import conformer_op
 * 
 * from bytedock.geometry.geometry import LigandGeometry
 * from bytedock.parser.ligand import LigandParser
 * 
 * from bdock_opt.app.prepare_ligand import calculate_molecular_weight
 * from bdock_opt.utils import write_to_json_file
 * 
 * def run():
 *     smiles = 'c1c(C(N)C2CC2)cc(C(O)=O)cc1'
 *     rkmol, _ = conformer_op.generate_conformer(smiles, n_confs=1, random_seed=59)
 *     data = LigandParser(rkmol).get_data()
 *     geo = LigandGeometry(data['bond_index'], data['is_rotatable'])
 *     data['geometry'] = {
 *         'num_atoms': geo.num_atoms.tolist(),  # numpy.int64 -> int
 *         'permute_index': geo.permute_index,
 *         'frag_traverse_levels': geo.frag_traverse_levels,
 *         'edge_to_bond': geo.frag_tree.edge_to_bond,
 *         'frag_split_index': geo.frag_split_index,
 *     }
 *     for key in ('bond_index', 'is_rotatable'):
 *         if key in data: del data[key]
 *     data['ffdata']['molecular_weight'] = calculate_molecular_weight(data['mapped_smiles'])
 *     for item in data['xyz'][0]:
 *         for i in range(3):
 *             item[i] = float('%.3f' % item[i])
 *     write_to_json_file(data, 'fake_ligand.json')
 * ```
 */

class ConformerTest : public testing::Test {
protected:
    void SetUp() override {
        sink_ = enable_global_logging("-", 1);

        auto ls = std::make_shared<std::stringstream>();
        (*ls) << R"ligand_json({
    "xyz": [[
        [ 0.686,  1.767,  0.214],
        [ 0.052,  0.590, -0.217],
        [-1.458,  0.562, -0.372],
        [-1.863,  0.238, -1.736],
        [-2.083, -0.441,  0.560],
        [-3.510, -0.265,  1.024],
        [-2.380, -0.076,  1.995],
        [ 0.851, -0.534, -0.493],
        [ 2.245, -0.483, -0.343],
        [ 3.028, -1.703, -0.653],
        [ 4.349, -1.542, -0.456],
        [ 2.573, -2.761, -1.046],
        [ 2.853,  0.701,  0.087],
        [ 2.073,  1.823,  0.365],
        [ 0.098,  2.655,  0.438],
        [-1.858,  1.559, -0.148],
        [-1.398,  0.883, -2.373],
        [-1.525, -0.690, -1.981],
        [-1.762, -1.460,  0.366],
        [-4.122, -1.155,  1.116],
        [-4.054,  0.612,  0.690],
        [-2.226, -0.837,  2.752],
        [-2.164,  0.932,  2.332],
        [ 0.389, -1.463, -0.824],
        [ 4.729, -2.414, -0.694],
        [ 3.932,  0.759,  0.209],
        [ 2.546,  2.743,  0.700]
    ]],
    "ffdata": {
        "FF_Bonds_atomidx": [],
        "FF_Bonds_k": [],
        "FF_Bonds_length": [],
        "FF_Angles_atomidx": [],
        "FF_Angles_k": [],
        "FF_Angles_angle": [],
        "FF_ProperTorsions_atomidx": [],
        "FF_ProperTorsions_periodicity": [],
        "FF_ProperTorsions_k": [],
        "FF_ProperTorsions_phase": [],
        "FF_ImproperTorsions_atomidx": [],
        "FF_ImproperTorsions_periodicity": [],
        "FF_ImproperTorsions_k": [],
        "FF_ImproperTorsions_phase": [],
        "FF_Nonbonded14_atomidx": [],
        "FF_NonbondedAll_atomidx": [],
        "hydrophobic_atomidx": [],
        "cation_atomidx": [],
        "anion_atomidx": [],
        "piring5_atomidx": [],
        "piring6_atomidx": [],
        "tstrain_atomidx": [],
        "tstrain_params_pose_selection": [],
        "hbonddon_charged_atomidx": [],
        "hbonddon_neut_atomidx": [],
        "hbondacc_charged_atomidx": [],
        "hbondacc_neut_atomidx": [],
        "partial_charges": [],
        "FF_vdW_paraidx": [],
        "rotatable_bond_index": [],
        "bond_index": [],
        "atomic_numbers": [],
        "hydrophobic_group": []
    },
    "geometry": {
        "num_atoms": 27,
        "permute_index": [
             0,  1,  7,  8, 12, 13, 14, 23, 25, 26,
             2, 15,  3, 16, 17,  4,  5,  6, 18, 19,
            20, 21, 22,  9, 11, 10, 24
        ],
        "frag_split_index": [10, 12, 15, 23, 25],
        "frag_traverse_levels": [
            {
                "0": {
                    "children": [1, 4],
                    "parent": []
                }
            },
            {
                "1": {
                    "children": [2, 3],
                    "parent": [0]
                },
                "4": {
                    "children": [5],
                    "parent": [0]
                }
            },
            {
                "2": {
                    "children": [],
                    "parent": [1]
                },
                "3": {
                    "children": [],
                    "parent": [1]
                },
                "5": {
                    "children": [],
                    "parent": [4]
                }
            }
        ],
        "edge_to_bond": {
            "0->1": [1, 2],
            "1->0": [2, 1],
            "0->4": [8, 9],
            "4->0": [9, 8],
            "1->2": [2, 3],
            "2->1": [3, 2],
            "1->3": [2, 4],
            "3->1": [4, 2],
            "4->5": [9, 10],
            "5->4": [10, 9]
        }
    }
})ligand_json";
        ligand_ = std::make_shared<free_ligand>();
        ligand_->parse(ls);

        auto rs = std::make_shared<std::stringstream>();
        (*rs) << R"receptor_json({
    "xyz": [
        [28.800, 40.380,  5.300],
        [15.410, 27.400,  6.380],
        [15.100, 27.430,  5.060],
        [14.600, 26.630,  4.840],
        [16.630, 31.420,  1.350],
        [17.020, 30.230,  0.680],
        [17.570, 30.450, -0.070],
        [ 2.680, 30.610, 11.500],
        [ 1.600, 31.410, 10.900],
        [ 1.520, 32.290, 11.390],
        [ 1.820, 31.600,  9.930],
        [ 0.730, 30.910, 10.960]
    ],
    "ffdata": {
        "FF_Bonds_atomidx": [],
        "FF_Bonds_k": [],
        "FF_Bonds_length": [],
        "FF_Angles_atomidx": [],
        "FF_Angles_k": [],
        "FF_Angles_angle": [],
        "FF_ProperTorsions_atomidx": [],
        "FF_ProperTorsions_periodicity": [],
        "FF_ProperTorsions_k": [],
        "FF_ProperTorsions_phase": [],
        "FF_ImproperTorsions_atomidx": [],
        "FF_ImproperTorsions_periodicity": [],
        "FF_ImproperTorsions_k": [],
        "FF_ImproperTorsions_phase": [],
        "FF_Nonbonded14_atomidx": [],
        "FF_NonbondedAll_atomidx": [],
        "hydrophobic_atomidx": [],
        "cation_atomidx": [],
        "anion_atomidx": [],
        "piring5_atomidx": [],
        "piring6_atomidx": [],
        "hbonddon_charged_atomidx": [],
        "hbonddon_neut_atomidx": [],
        "hbondacc_charged_atomidx": [],
        "hbondacc_neut_atomidx": [],
        "partial_charges": [],
        "FF_vdW_sigma": [],
        "FF_vdW_epsilon": [],
        "rotatable_bond_index": [],
        "bond_index": [],
        "atomic_numbers": []
    },
    "geometry": {
        "num_atoms": 12,
        "num_rotatable_bonds": 3,
        "data": {
            "2": {
                "fragments": [[2, 3], [5, 6]],
                "rot_bond_1": [1, 4],
                "rot_bond_2": [2, 5],
                "rot_bond_order": [0, 1]
            },
            "4": {
                "fragments": [[8, 9, 10, 11]],
                "rot_bond_1": [7],
                "rot_bond_2": [8],
                "rot_bond_order": [2]
            }
        }
    }
})receptor_json";
        receptor_ = std::make_shared<torsional_receptor>();
        receptor_->parse(rs);
    }

    void TearDown() override {
        disable_sink(sink_);
    }

    boost::shared_ptr<sink_t> sink_;
    std::shared_ptr<free_ligand> ligand_;
    std::shared_ptr<torsional_receptor> receptor_;
};

/**
 * Rotates each of the receptor's three rotatable bonds by a fixed torsion and
 * compares the resulting pose against precomputed reference coordinates.
 * Atoms outside every rotated fragment (and atoms on a rotation axis) must
 * keep their input coordinates.
 */
TEST_F(ConformerTest, ApplyReceptorParameters) {
    std::vector<param_t> torsions = {
        kMathPi / 2_r,   // atom#1 -> atom#2 with 3
        kMathPi,         // atom#4 -> atom#5 with 6
        -kMathPi / 3_r,  // atom#7 -> atom#8 with 9,10,11
    };
    auto new_xyz = receptor_->apply_parameters(torsions);
    const molecule_pose ref_xyz = {
        {28.800000, 40.380000,  5.300000},
        {15.410000, 27.400000,  6.380000},
        {15.100000, 27.430000,  5.060000},
        {14.245493, 27.873225,  4.951510},
        {16.630000, 31.420000,  1.350000},
        {17.020000, 30.230000,  0.680000},
        {16.646023, 29.472904,  1.127602},
        { 2.680000, 30.610000, 11.500000},
        { 1.600000, 31.410000, 10.900000},
        { 0.894953, 31.602039, 11.597803},
        { 1.975777, 32.290713, 10.570551},
        { 1.187745, 30.907720, 10.133019}
    };
    // References are printed to 6 decimals: allow half a unit in that place.
    constexpr double kTolerance = 5e-7;

    // The loop below indexes new_xyz by reference index; stop early instead of
    // reading past the end if the returned pose has the wrong atom count.
    ASSERT_EQ(ref_xyz.size(), new_xyz.size());
    for (size_t i = 0; i < ref_xyz.size(); ++i) {
        for (size_t k = 0; k < 3; ++k) {
            EXPECT_NEAR(ref_xyz[i].xyz[k], new_xyz[i].xyz[k], kTolerance)
                << "atom #" << i << ", axis " << k;
        }
    }
}

/**
 * With the identity orientation and all torsions at zero, the test expects
 * apply_parameters to only translate the ligand: each atom's offset from the
 * geometric center must be the same before and after moving the center to
 * `new_center`.
 */
TEST_F(ConformerTest, TranslateLigand) {
    auto& ligand_xyz = ligand_->get_pose(0);
    atom_position old_center = calculate_geometric_center(ligand_xyz);
    atom_position new_center = {10_r, 11_r, 12_r};
    matrix_3x3 orientation = {1_r, 0_r, 0_r, 0_r, 1_r, 0_r, 0_r, 0_r, 1_r};
    std::vector<param_t> torsions(ligand_->num_torsions(), 0_r);
    auto new_xyz = ligand_->apply_parameters(ligand_xyz, new_center,
                                             orientation, torsions);
    constexpr double kTolerance = 1e-7;

    // Guard the paired indexing below against a result with a different atom count.
    ASSERT_EQ(ligand_xyz.size(), new_xyz.size());
    for (size_t i = 0; i < ligand_xyz.size(); ++i) {
        for (size_t k = 0; k < 3; ++k) {
            EXPECT_NEAR(ligand_xyz[i].xyz[k] - old_center.xyz[k],
                        new_xyz[i].xyz[k] - new_center.xyz[k], kTolerance)
                << "atom #" << i << ", axis " << k;
        }
    }
}

/**
 * Applies a half-turn about the z axis (orientation = diag(-1, -1, 1)) with all
 * torsions at zero, placing the ligand's center at the origin. Each atom's
 * offset from the original center must have its x and y components negated
 * and its z component unchanged.
 */
TEST_F(ConformerTest, RotateLigand) {
    auto& ligand_xyz = ligand_->get_pose(0);
    atom_position old_center = calculate_geometric_center(ligand_xyz);
    atom_position new_center = {0_r, 0_r, 0_r};
    matrix_3x3 orientation = {-1_r, 0_r, 0_r, 0_r, -1_r, 0_r, 0_r, 0_r, 1_r};
    std::vector<param_t> torsions(ligand_->num_torsions(), 0_r);
    auto new_xyz = ligand_->apply_parameters(ligand_xyz, new_center,
                                             orientation, torsions);
    // Diagonal of `orientation`: the per-axis sign applied to each offset.
    const double kAxisSign[3] = {-1.0, -1.0, 1.0};
    constexpr double kTolerance = 1e-7;

    // Guard the paired indexing below against a result with a different atom count.
    ASSERT_EQ(ligand_xyz.size(), new_xyz.size());
    for (size_t i = 0; i < ligand_xyz.size(); ++i) {
        for (size_t k = 0; k < 3; ++k) {
            // new_center is the origin, so the rotated offset is the coordinate itself.
            EXPECT_NEAR(kAxisSign[k] * (ligand_xyz[i].xyz[k] - old_center.xyz[k]),
                        new_xyz[i].xyz[k], kTolerance)
                << "atom #" << i << ", axis " << k;
        }
    }
}

/**
 * Applies a combined translation, rotation and per-fragment torsions to the
 * ligand and compares the result with precomputed reference coordinates.
 * The torsion vector is ordered by fragment edge, as annotated below.
 */
TEST_F(ConformerTest, MutateLigand) {
    atom_position new_center = {-1_r, -2_r, -3_r};
    matrix_3x3 orientation = {
        std::sqrt(2) / 4_r, -std::sqrt(6) / 8_r - std::sqrt(3) / 4_r, std::sqrt(2) / 8_r - 0.75_r,
        std::sqrt(6) / 4_r, 0.25_r - std::sqrt(18) / 8_r, std::sqrt(6) / 8_r + std::sqrt(3) / 4_r,
        -std::sqrt(2) / 2_r, -std::sqrt(6) / 4_r, std::sqrt(2) / 4_r
    };  // Euler Angles ZYX => (pi/3, pi/4, -pi/3), i.e. Rz(pi/3) * Ry(pi/4) * Rx(-pi/3)
    std::vector<param_t> torsions = {
        kMathPi/2,  // fragment#0 -> fragment#1
        -kMathPi,   // fragment#1 -> fragment#2
        0_r,        // fragment#1 -> fragment#3
        kMathPi,    // fragment#0 -> fragment#4
        -kMathPi/2  // fragment#4 -> fragment#5
    };
    auto new_xyz = ligand_->apply_parameters(ligand_->get_pose(0), new_center,
                                             orientation, torsions);
    const molecule_pose ref_xyz = {
        {-2.154209, -1.958485, -4.511250},
        {-1.261266, -2.335375, -3.494563},
        {-1.685584, -3.366784, -2.464486},
        {-1.049737, -3.142559, -1.170300},
        {-3.177796, -3.359105, -2.268266},
        {-3.890403, -4.602241, -1.789304},
        {-4.079913, -4.130074, -3.202317},
        { 0.010293, -1.735018, -3.468816},
        { 0.379463, -0.784787, -4.432720},
        { 1.735818, -0.192449, -4.348892},
        { 2.459659, -0.683096, -3.326595},
        { 2.193265,  0.649541, -5.098981},
        {-0.527274, -0.426520, -5.435662},
        {-1.791783, -1.013204, -5.472913},
        {-3.146909, -2.401913, -4.560062},
        {-1.372945, -4.361594, -2.806240},
        {-1.776214, -3.091916, -0.457779},
        {-0.475629, -3.947074, -0.928210},
        {-3.552369, -2.413113, -1.889015},
        {-4.709986, -4.474396, -1.091418},
        {-3.302540, -5.498038, -1.620227},
        {-5.032260, -3.684236, -3.467175},
        {-3.629927, -4.714830, -3.997214},
        { 0.723404, -2.002182, -2.690264},
        { 2.918414, -1.457612, -3.715306},
        {-0.258596,  0.308153, -6.191015},
        {-2.496645, -0.733824, -6.252317}
    };
    // References are printed to 6 decimals: allow half a unit in that place.
    constexpr double kTolerance = 5e-7;

    // The loop below indexes new_xyz by reference index; stop early instead of
    // reading past the end if the returned pose has the wrong atom count.
    ASSERT_EQ(ref_xyz.size(), new_xyz.size());
    for (size_t i = 0; i < ref_xyz.size(); ++i) {
        for (size_t k = 0; k < 3; ++k) {
            EXPECT_NEAR(ref_xyz[i].xyz[k], new_xyz[i].xyz[k], kTolerance)
                << "atom #" << i << ", axis " << k;
        }
    }
}

}
